package com.xlongwei.session;

import java.io.IOException;
import java.io.Serializable;
import java.lang.reflect.Type;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.UUID;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonDeserializationContext;
import com.google.gson.JsonDeserializer;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParseException;
import com.google.gson.JsonPrimitive;
import com.google.gson.JsonSerializationContext;
import com.google.gson.JsonSerializer;
import com.google.gson.reflect.TypeToken;
import com.xlongwei.session.JedisUtil.Callback;

import redis.clients.jedis.HostAndPort;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisCluster;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.JedisPoolConfig;
import redis.clients.jedis.JedisSentinelPool;
import redis.clients.jedis.Protocol;

/**
 * Servlet filter that stores {@link HttpSession} state in Redis instead of the servlet container.
 * <p>
 * Each HTTP request is wrapped in a {@link RequestWrapper} whose {@code getSession} methods return a
 * {@link RedisSession}. The session id travels in the cookie named by the {@code sessionCookieName}
 * init parameter (default {@code JSESSIONID}). The attributes are stored as one JSON string under the
 * key {@code prefix + id} and written back after the filter chain returns.
 * <p>
 * {@link #init(FilterConfig)} chooses the connection mode as follows:
 * <ul>
 * <li>a sentinel pool when {@code sentinelMaster} and {@code hostAndPorts} are both set;</li>
 * <li>a cluster when {@code hostAndPorts} lists more than one node;</li>
 * <li>otherwise a single {@link JedisPool} on {@code host}/{@code port}.</li>
 * </ul>
 *
 * @see https://gitee.com/darkidiot/distributedsession
 */
public class RedisSessionFilter implements Filter {
	private static final Logger log = LoggerFactory.getLogger(RedisSessionFilter.class);
	/** JSON property that records the Java class of a serialized attribute value. */
	static final String FULL_CLASS_NAME = "FULL_CLASS_NAME";
	/** Map key under which the session creation time (epoch millis, as a string) is persisted. */
	static final String SESSION_CREATE_TIME = "SESSION_CREATE_TIME";
	static JedisPoolConfig poolConfig = new JedisPoolConfig();
	static GsonSerializer gson = new GsonSerializer();
	// After init() at most one of pool / cluster / sentinel is non-null; all three are handed to
	// JedisUtil.exec, which presumably uses whichever one is set (TODO confirm in JedisUtil).
	static JedisPool pool = null;
	static JedisCluster cluster = null;
	static String sessionCookieName, cookieDomain, cookieContextPath;
	static int cookieMaxAge, maxInactiveInterval;
	static String prefix = null;
	static String sentinelMaster = null;
	static JedisSentinelPool sentinel = null;

	/**
	 * Reads the filter's init parameters, configures the Jedis pool and opens the Redis connection.
	 *
	 * @param filterConfig source of the init parameters
	 * @throws ServletException if a numeric init parameter cannot be parsed
	 */
	public void init(FilterConfig filterConfig) throws ServletException {
		try {
			sessionCookieName = JedisUtil.firstNonBlank(filterConfig.getInitParameter("sessionCookieName"), "JSESSIONID");
			cookieDomain = JedisUtil.firstNonBlank(filterConfig.getInitParameter("cookieDomain"));
			cookieContextPath = JedisUtil.firstNonBlank(filterConfig.getInitParameter("cookieContextPath"), "/");
			cookieMaxAge = Integer.parseInt(JedisUtil.firstNonBlank(filterConfig.getInitParameter("cookieMaxAge"), "-1"));
			maxInactiveInterval = Integer.parseInt(JedisUtil.firstNonBlank(filterConfig.getInitParameter("maxInactiveInterval"), "1800"));
			// NOTE(review): if firstNonBlank returns null for a blank "prefix", keys are built as
			// "null" + id. This is kept as-is so sessions already stored in Redis stay reachable.
			prefix = JedisUtil.firstNonBlank(filterConfig.getInitParameter("prefix"));
			sentinelMaster = JedisUtil.firstNonBlank(filterConfig.getInitParameter("sentinelMaster"));
			configurePool(filterConfig);
			connect(filterConfig);
		} catch (NumberFormatException e) {
			throw new ServletException("invalid RedisSessionFilter init parameter: " + e.getMessage(), e);
		}
	}

	/** Copies the pool-related init parameters onto the shared {@link #poolConfig}. */
	private static void configurePool(FilterConfig filterConfig) {
		poolConfig.setMaxTotal(Integer.parseInt(JedisUtil.firstNonBlank(filterConfig.getInitParameter("maxTotal"), "8")));
		poolConfig.setMaxIdle(Integer.parseInt(JedisUtil.firstNonBlank(filterConfig.getInitParameter("maxIdle"), "4")));
		poolConfig.setMinIdle(Integer.parseInt(JedisUtil.firstNonBlank(filterConfig.getInitParameter("minIdle"), "2")));
		// "lifo" is the intended name. The misspelled "lilo" is still honoured for existing configurations.
		poolConfig.setLifo(Boolean.parseBoolean(JedisUtil.firstNonBlank(filterConfig.getInitParameter("lifo"),
				JedisUtil.firstNonBlank(filterConfig.getInitParameter("lilo"), "true"))));
		poolConfig.setFairness(Boolean.parseBoolean(JedisUtil.firstNonBlank(filterConfig.getInitParameter("fairness"), "false")));
		poolConfig.setMaxWaitMillis(Long.parseLong(JedisUtil.firstNonBlank(filterConfig.getInitParameter("maxWaitMillis"), "2000")));
		poolConfig.setMinEvictableIdleTimeMillis(Long.parseLong(JedisUtil.firstNonBlank(filterConfig.getInitParameter("minEvictableIdleTimeMillis"), "1800000")));
		poolConfig.setSoftMinEvictableIdleTimeMillis(Long.parseLong(JedisUtil.firstNonBlank(filterConfig.getInitParameter("softMinEvictableIdleTimeMillis"), "1800000")));
		poolConfig.setNumTestsPerEvictionRun(Integer.parseInt(JedisUtil.firstNonBlank(filterConfig.getInitParameter("numTestsPerEvictionRun"), "3")));
		poolConfig.setEvictionPolicyClassName(JedisUtil.firstNonBlank(filterConfig.getInitParameter("evictionPolicyClassName"), "org.apache.commons.pool2.impl.DefaultEvictionPolicy"));
		poolConfig.setTestOnCreate(Boolean.parseBoolean(JedisUtil.firstNonBlank(filterConfig.getInitParameter("testOnCreate"), "false")));
		poolConfig.setTestOnBorrow(Boolean.parseBoolean(JedisUtil.firstNonBlank(filterConfig.getInitParameter("testOnBorrow"), "false")));
		poolConfig.setTestOnReturn(Boolean.parseBoolean(JedisUtil.firstNonBlank(filterConfig.getInitParameter("testOnReturn"), "false")));
		poolConfig.setTestWhileIdle(Boolean.parseBoolean(JedisUtil.firstNonBlank(filterConfig.getInitParameter("testWhileIdle"), "true")));
		poolConfig.setTimeBetweenEvictionRunsMillis(Long.parseLong(JedisUtil.firstNonBlank(filterConfig.getInitParameter("timeBetweenEvictionRunsMillis"), "-1")));
		poolConfig.setBlockWhenExhausted(Boolean.parseBoolean(JedisUtil.firstNonBlank(filterConfig.getInitParameter("blockWhenExhausted"), "true")));
		log.info("JedisPool maxTotal={}, maxIdle={}, minIdle={}, lifo={}", poolConfig.getMaxTotal(), poolConfig.getMaxIdle(), poolConfig.getMinIdle(), poolConfig.getLifo());
	}

	/**
	 * Opens the sentinel pool, cluster or single pool, then publishes it.
	 * <p>
	 * The two connection kinds that are not used are reset to null, so a stale reference from an
	 * earlier init() can never be picked up.
	 */
	private static void connect(FilterConfig filterConfig) {
		int timeout = Integer.parseInt(JedisUtil.firstNonBlank(filterConfig.getInitParameter("timeout"), String.valueOf(Protocol.DEFAULT_TIMEOUT)));
		String password = filterConfig.getInitParameter("password");
		int database = Integer.parseInt(JedisUtil.firstNonBlank(filterConfig.getInitParameter("database"), String.valueOf(Protocol.DEFAULT_DATABASE)));
		Set<HostAndPort> nodes = JedisUtil.nodes(filterConfig.getInitParameter("hostAndPorts"));
		boolean hasNodes = nodes != null && !nodes.isEmpty();
		JedisPool newPool = null;
		JedisSentinelPool newSentinel = null;
		JedisCluster newCluster = null;
		if (hasNodes && !JedisUtil.isBlank(sentinelMaster)) {
			// A single sentinel is a valid deployment, so any non-empty node list qualifies here.
			Set<String> sentinels = JedisUtil.sentinels(nodes);
			log.info("JedisSentinel master={},timeout={},database={} {} sentinels={}", sentinelMaster, timeout, database, sentinels.size(), sentinels);
			newSentinel = new JedisSentinelPool(sentinelMaster, sentinels, poolConfig, timeout, password, database);
		} else if (hasNodes && nodes.size() > 1) {
			newCluster = new JedisCluster(nodes, poolConfig);
			log.info("JedisCluster {} nodes={}", nodes.size(), nodes);
		} else {
			String host = JedisUtil.firstNonBlank(filterConfig.getInitParameter("host"), Protocol.DEFAULT_HOST);
			int port = Integer.parseInt(JedisUtil.firstNonBlank(filterConfig.getInitParameter("port"), String.valueOf(Protocol.DEFAULT_PORT)));
			newPool = new JedisPool(poolConfig, host, port, timeout, password, database);
			log.warn("host={},port={},timeout={},database={},minIdle={},maxTotal={},lifo={}", host, port, timeout, database, poolConfig.getMinIdle(), poolConfig.getMaxTotal(), poolConfig.getLifo());
		}
		pool = newPool;
		sentinel = newSentinel;
		cluster = newCluster;
	}

	/**
	 * Wraps HTTP requests so that {@code getSession} is served from Redis.
	 * <p>
	 * Session changes are persisted after the chain completes, even when the chain throws.
	 * Requests that are already wrapped, or that are not HTTP, pass through untouched.
	 */
	public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
		if (request instanceof RequestWrapper
				|| !(request instanceof HttpServletRequest)
				|| !(response instanceof HttpServletResponse)) {
			chain.doFilter(request, response);
			return;
		}
		RequestWrapper requestWrapper = new RequestWrapper((HttpServletRequest) request, (HttpServletResponse) response);
		try {
			chain.doFilter(requestWrapper, response);
		} finally {
			// Mirror container behaviour: session changes made before an exception are still saved.
			requestWrapper.afterRequest();
		}
	}

	/** Closes the Redis connection opened by {@link #init(FilterConfig)} and drops the references. */
	public void destroy() {
		JedisUtil.close(pool, sentinel, cluster);
		pool = null;
		sentinel = null;
		cluster = null;
	}

	/**
	 * Request wrapper that replaces the container's session handling with {@link RedisSession}.
	 * <p>
	 * The session id is read from the cookie named {@code sessionCookieName}. If that cookie is
	 * absent, the container's requested session id is used, as before.
	 */
	public static class RequestWrapper extends HttpServletRequestWrapper {
		private RedisSession currentSession = null;
		final HttpServletResponse response;

		public RequestWrapper(HttpServletRequest request, HttpServletResponse response) {
			super(request);
			this.response = response;
		}

		/**
		 * Returns the id carried by this filter's own session cookie.
		 * <p>
		 * If that cookie is absent, returns the container's requested session id instead. The
		 * container only knows its own cookie name, so a custom {@code sessionCookieName} would
		 * otherwise never be read back.
		 */
		@Override
		public String getRequestedSessionId() {
			Cookie[] cookies = getCookies();
			if (cookies != null) {
				for (Cookie cookie : cookies) {
					String value = cookie.getValue();
					if (sessionCookieName.equals(cookie.getName()) && value != null && value.length() > 0) {
						return value;
					}
				}
			}
			return super.getRequestedSessionId();
		}

		/**
		 * Returns the session for this request.
		 * <p>
		 * With {@code create == false}, an existing Redis-backed session is still returned.
		 * {@code null} is returned only when no stored session exists for the requested id.
		 */
		@Override
		public synchronized HttpSession getSession(boolean create) {
			if (currentSession == null && (create || hasStoredSession())) {
				currentSession = new RedisSession(this);
			}
			return currentSession;
		}

		@Override
		public HttpSession getSession() {
			return getSession(true);
		}

		/** Persists, refreshes or deletes the session, if one was used during this request. */
		public void afterRequest() {
			if (currentSession != null) {
				currentSession.afterRequest();
			}
		}

		/** Returns true when Redis holds data for the requested session id; false on any Redis error. */
		private boolean hasStoredSession() {
			String sessionId = getRequestedSessionId();
			if (sessionId == null || sessionId.length() == 0) {
				return false;
			}
			final String sid = prefix + sessionId;
			try {
				Boolean exists = JedisUtil.exec(pool, sentinel, cluster, new Callback<Boolean>() {
					public Boolean execute(Jedis jedis) {
						return jedis.exists(sid);
					}
				});
				return Boolean.TRUE.equals(exists);
			} catch (RuntimeException e) {
				log.warn("{} exists check failed", sid, e);
				return false;
			}
		}
	}

	/**
	 * Redis-backed {@link HttpSession}.
	 * <p>
	 * Attributes loaded from Redis are kept in {@code dbSession}. Changes made during the request
	 * are buffered in {@code newAttributes} and {@code deleteAttribute}, then written back by
	 * {@link #afterRequest()}. The attribute collections are plain {@link HashMap}/{@link HashSet},
	 * so an instance is not meant to be shared across threads.
	 */
	public static class RedisSession implements HttpSession, Serializable {
		private static final long serialVersionUID = 1L;
		private final RequestWrapper requestWrapper;
		private final String id;
		private final long createdAt;
		private volatile long lastAccessedAt;
		private final Map<String, Object> newAttributes = new HashMap<String, Object>();
		private final Set<String> deleteAttribute = new HashSet<String>();
		private final Map<String, Object> dbSession;
		/** True when this request generated the id, i.e. the client does not know the session yet. */
		private final boolean fresh;
		private volatile boolean invalid;
		private volatile boolean dirty;

		/**
		 * Loads the session named by the request's session cookie, or starts a new one.
		 * <p>
		 * A client-supplied id with no data in Redis is never adopted. A new random id is issued
		 * instead (with a new cookie), which prevents session fixation through a planted cookie.
		 */
		public RedisSession(RequestWrapper request) {
			requestWrapper = request;
			lastAccessedAt = System.currentTimeMillis();
			String requestedId = request.getRequestedSessionId();
			Map<String, Object> stored = Collections.emptyMap();
			if (requestedId != null && requestedId.length() > 0) {
				stored = loadDBSession(requestedId);
			}
			if (stored.isEmpty()) {
				id = UUID.randomUUID().toString().replace("-", "");
				fresh = true;
				addCookie();
			} else {
				id = requestedId;
				fresh = false;
				log.debug("cookie {}={} found", sessionCookieName, id);
			}
			dbSession = stored;
			createdAt = dbSession.containsKey(SESSION_CREATE_TIME)
					? parseCreateTime(dbSession.remove(SESSION_CREATE_TIME), lastAccessedAt)
					: lastAccessedAt;
		}

		/** Parses a persisted creation time; falls back to {@code fallback} if it is missing or malformed. */
		private static long parseCreateTime(Object value, long fallback) {
			if (value != null) {
				try {
					return Long.parseLong(value.toString());
				} catch (NumberFormatException e) {
					log.debug("invalid {}={}", SESSION_CREATE_TIME, value);
				}
			}
			return fallback;
		}

		/** Builds the session cookie with this filter's path, domain, secure and HttpOnly settings. */
		private Cookie sessionCookie(String value, int maxAge) {
			Cookie cookie = new Cookie(sessionCookieName, value);
			cookie.setPath(cookieContextPath);
			if (cookieDomain != null && cookieDomain.length() > 0) {
				cookie.setDomain(cookieDomain);
			}
			cookie.setMaxAge(maxAge);
			cookie.setSecure(requestWrapper.isSecure());
			cookie.setHttpOnly(true);
			return cookie;
		}

		// NOTE(review): cookies added after the response is committed are ignored by the container.
		private void addCookie() {
			requestWrapper.response.addCookie(sessionCookie(id, cookieMaxAge));
			log.debug("cookie {}={} create", sessionCookieName, id);
		}

		/**
		 * Reads and deserializes the stored attributes for {@code sessionId}.
		 *
		 * @return a mutable map of the stored attributes, or an empty map when none are stored or
		 *         loading fails. On failure the key is deleted, on a best-effort basis.
		 */
		private static Map<String, Object> loadDBSession(String sessionId) {
			final String sid = prefix + sessionId;
			try {
				String session = JedisUtil.exec(pool, sentinel, cluster, new Callback<String>() {
					public String execute(Jedis jedis) {
						return jedis.get(sid);
					}
				});
				if (session != null && session.length() > 0) {
					Map<String, Object> map = gson.deserialize(session);
					if (map != null) {
						log.debug("{} reload", sid);
						return new HashMap<String, Object>(map);
					}
				}
				log.debug("{} create", sid);
			} catch (Exception e) {
				log.warn("{} load failed", sid, e);
				try {
					deleteKey(sid);
				} catch (RuntimeException deleteFailure) {
					log.debug("{} delete failed: {}", sid, deleteFailure.getMessage());
				}
			}
			return Collections.emptyMap();
		}

		/** Writes back, refreshes, or deletes the session, depending on what happened during the request. */
		public void afterRequest() {
			if (invalid) {
				deleteSession();
				deleteCookie();
			} else if (dirty) {
				saveSession();
			} else {
				refreshExpireTime();
			}
		}

		private void deleteSession() {
			deleteKey(prefix + id);
		}

		/** Deletes one Redis key and logs whether a key was actually removed. */
		private static void deleteKey(final String sid) {
			Boolean delete = JedisUtil.exec(pool, sentinel, cluster, new Callback<Boolean>() {
				public Boolean execute(Jedis jedis) {
					Long del = jedis.del(sid);
					return del != null && del == 1L;
				}
			});
			log.debug("{} delete={}", sid, delete);
		}

		/**
		 * Expires the session cookie in the browser.
		 * <p>
		 * Browsers only drop a cookie whose name, path and domain all match. So the cookie is
		 * re-issued with the attributes {@link #addCookie()} uses and Max-Age=0. Cookies read
		 * from the request carry no path/domain, and the request may have no cookies at all.
		 */
		private void deleteCookie() {
			requestWrapper.response.addCookie(sessionCookie("", 0));
			log.debug("cookie {}={}, remove={}", sessionCookieName, id, true);
		}

		/**
		 * Merges loaded and changed attributes and stores them under {@code prefix + id}.
		 * <p>
		 * The key is deleted when no attributes remain. The key is stored without a TTL when
		 * {@code maxInactiveInterval <= 0}, because SETEX rejects non-positive expiry times.
		 */
		private void saveSession() {
			final Map<String, Object> snap = new HashMap<String, Object>(dbSession);
			snap.putAll(newAttributes);
			snap.keySet().removeAll(deleteAttribute);
			final String sid = prefix + id;
			Boolean save = JedisUtil.exec(pool, sentinel, cluster, new Callback<Boolean>() {
				public Boolean execute(Jedis jedis) {
					if (snap.isEmpty()) {
						jedis.del(sid);
						return true;
					}
					snap.put(SESSION_CREATE_TIME, String.valueOf(createdAt));
					String value = gson.serialize(snap);
					String reply = maxInactiveInterval > 0
							? jedis.setex(sid, maxInactiveInterval, value)
							: jedis.set(sid, value);
					return "OK".equals(reply);
				}
			});
			log.debug("{} save={}", sid, save);
		}

		/**
		 * Extends the TTL of an unchanged, already stored session.
		 * <p>
		 * Skipped for sessions created in this request, which have nothing stored yet. Also
		 * skipped when {@code maxInactiveInterval <= 0}, because EXPIRE with a non-positive
		 * timeout deletes the key.
		 */
		private void refreshExpireTime() {
			if (fresh || maxInactiveInterval <= 0) {
				return;
			}
			final String sid = prefix + id;
			Boolean refresh = JedisUtil.exec(pool, sentinel, cluster, new Callback<Boolean>() {
				public Boolean execute(Jedis jedis) {
					Long ret = jedis.expire(sid, maxInactiveInterval);
					return ret != null && ret == 1L;
				}
			});
			log.debug("{} refresh={}", sid, refresh);
		}

		public long getCreationTime() {
			return createdAt;
		}

		public String getId() {
			return id;
		}

		public long getLastAccessedTime() {
			return lastAccessedAt;
		}

		public ServletContext getServletContext() {
			return requestWrapper.getServletContext();
		}

		public void setMaxInactiveInterval(int interval) {
			//ignore maxInactiveInterval is from InitParameter
		}

		public int getMaxInactiveInterval() {
			return maxInactiveInterval;
		}

		public Object getAttribute(String name) {
			if (newAttributes.containsKey(name)) {
				return newAttributes.get(name);
			}
			if (deleteAttribute.contains(name)) {
				return null;
			}
			return dbSession.get(name);
		}

		public Enumeration<String> getAttributeNames() {
			Set<String> names = new HashSet<String>(dbSession.keySet());
			names.addAll(newAttributes.keySet());
			names.removeAll(deleteAttribute);
			return Collections.enumeration(names);
		}

		/** Buffers an attribute change; a {@code null} value removes the attribute. */
		public void setAttribute(String name, Object value) {
			if (value != null) {
				newAttributes.put(name, value);
				deleteAttribute.remove(name);
			} else {
				deleteAttribute.add(name);
				newAttributes.remove(name);
			}
			dirty = true;
		}

		public void removeAttribute(String name) {
			deleteAttribute.add(name);
			newAttributes.remove(name);
			dirty = true;
		}

		/** Marks the session for deletion (Redis key and cookie) when the request completes. */
		public void invalidate() {
			invalid = true;
			dirty = true;
		}

		/** True only when the session id was generated during this request. */
		public boolean isNew() {
			return fresh;
		}

		// Deprecated Servlet 2.2 aliases delegate to their attribute-based replacements.
		@Deprecated public javax.servlet.http.HttpSessionContext getSessionContext() { return null; }
		@Deprecated public Object getValue(String name) { return getAttribute(name); }
		@Deprecated public String[] getValueNames() { return Collections.list(getAttributeNames()).toArray(new String[0]); }
		@Deprecated public void putValue(String name, Object value) { setAttribute(name, value); }
		@Deprecated public void removeValue(String name) { removeAttribute(name); }
	}

	/**
	 * Converts session attribute maps to and from JSON.
	 * <p>
	 * Values that serialize to a JSON primitive are stored as-is and are read back as Strings. An
	 * Integer attribute, for example, reloads as "5"; this is a known limitation of the stored format.
	 * Values that serialize to a JSON object get an extra {@code FULL_CLASS_NAME} property. Any other
	 * value (e.g. a List, which serializes to a JSON array) is wrapped as
	 * {@code {FULL_CLASS_NAME: ..., FULL_CLASS_VALUE: ...}}.
	 */
	public static class GsonSerializer {
		/** Wrapper property that holds values that do not serialize to a JSON object or primitive. */
		private static final String FULL_CLASS_VALUE = "FULL_CLASS_VALUE";
		private final Gson gson;
		private final Type typeToken;

		public GsonSerializer() {
			this.typeToken = new TypeToken<HashMap<String, Object>>() { }.getType();
			this.gson = new GsonBuilder().registerTypeAdapter(typeToken, new MapSerializer()).create();
		}

		public String serialize(Map<String, Object> map) {
			return gson.toJson(map, typeToken);
		}

		public Map<String, Object> deserialize(String str) {
			return gson.fromJson(str, typeToken);
		}

		static class MapSerializer implements JsonSerializer<HashMap<String, Object>>, JsonDeserializer<HashMap<String, Object>> {
			private final Gson gson = new Gson();

			public HashMap<String, Object> deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context) throws JsonParseException {
				HashMap<String, Object> map = new HashMap<String, Object>();
				for (Map.Entry<String, JsonElement> entry : json.getAsJsonObject().entrySet()) {
					JsonElement element = entry.getValue();
					if (element == null || element.isJsonNull()) {
						continue;
					}
					if (element.isJsonPrimitive()) {
						map.put(entry.getKey(), element.getAsString());
					} else if (element.isJsonObject()) {
						map.put(entry.getKey(), restore(element.getAsJsonObject()));
					} else {
						// Not produced by serialize(); keep the raw JSON instead of failing the whole session.
						map.put(entry.getKey(), element);
					}
				}
				return map;
			}

			/**
			 * Rebuilds a typed value from an object carrying {@code FULL_CLASS_NAME}.
			 * <p>
			 * Returns the raw JSON when the class name is missing or the class cannot be loaded.
			 */
			private Object restore(JsonObject obj) {
				JsonElement className = obj.remove(FULL_CLASS_NAME);
				if (className == null || !className.isJsonPrimitive()) {
					log.debug("{} missing, keeping raw JSON", FULL_CLASS_NAME);
					return obj;
				}
				boolean wrapped = obj.entrySet().size() == 1 && obj.has(FULL_CLASS_VALUE);
				JsonElement payload = wrapped ? obj.get(FULL_CLASS_VALUE) : obj;
				try {
					return gson.fromJson(payload, loadClass(className.getAsString()));
				} catch (ClassNotFoundException e) {
					log.debug(e.getMessage());
					return payload;
				}
			}

			/**
			 * Loads a class through the thread context class loader first.
			 * <p>
			 * In a servlet container this filter's own loader may not see web-app classes, so the
			 * context loader is tried before falling back to {@link Class#forName(String)}.
			 */
			private static Class<?> loadClass(String name) throws ClassNotFoundException {
				ClassLoader loader = Thread.currentThread().getContextClassLoader();
				if (loader != null) {
					try {
						return Class.forName(name, true, loader);
					} catch (ClassNotFoundException ignored) {
						// fall back to this class's own loader below
					}
				}
				return Class.forName(name);
			}

			public JsonElement serialize(HashMap<String, Object> src, Type typeOfSrc, JsonSerializationContext context) {
				JsonObject json = new JsonObject();
				for (Map.Entry<String, Object> entry : src.entrySet()) {
					String key = entry.getKey();
					Object value = entry.getValue();
					if (value == null || key == null) {
						continue;
					}
					JsonElement element = context.serialize(value);
					if (element == null || element.isJsonNull()) {
						// e.g. anonymous/local classes, which Gson refuses to serialize
						log.debug("{} not serializable, skipped", key);
						continue;
					}
					if (element.isJsonPrimitive()) {
						json.add(key, element);
					} else if (element.isJsonObject()) {
						JsonObject obj = element.getAsJsonObject();
						obj.addProperty(FULL_CLASS_NAME, value.getClass().getName());
						json.add(key, obj);
					} else {
						// Arrays (Lists, Sets, Java arrays) cannot carry an extra property, so they are wrapped.
						JsonObject wrapper = new JsonObject();
						wrapper.addProperty(FULL_CLASS_NAME, value.getClass().getName());
						wrapper.add(FULL_CLASS_VALUE, element);
						json.add(key, wrapper);
					}
				}
				return json;
			}
		}
	}
}
